{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/run-llama/llama_index/blob/main/docs/examples/workflow/multi_step_query_engine.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MultiStep Query Engine\n",
    "\n",
    "The `MultiStepQueryEngine` breaks down a complex query into sequential sub-questions.\n",
    "\n",
    "To answer the query: In which city did the author found his first company, Viaweb?, we need to answer the following sub-questions sequentially:\n",
    "\n",
    "1. Who is the author that founded his first company, Viaweb?\n",
    "2. In which city did Paul Graham found his first company, Viaweb?\n",
    "\n",
    "As an example, the answer from each step (sub-query-1) is used to generate the next step's question (sub-query-2), with steps created sequentially rather than all at once.\n",
    "\n",
    "In this notebook, we will implement the same with [MultiStepQueryEngine](https://docs.llamaindex.ai/en/stable/examples/query_transformations/simpleindexdemo-multistep/) using workflows."
   ]
  },
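  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sequential decomposition described above can be sketched without any LLM. This is a toy illustration under stated assumptions, not the library's implementation: `decompose`, `answer`, `run_multi_step`, and `CANNED_ANSWERS` are hypothetical names, and the canned strings stand in for the LLM-backed query transform and the query engine used later in this notebook.\n",
    "\n",
    "```python\n",
    "# Toy illustration only: decompose() and answer() are hypothetical stand-ins\n",
    "# for the LLM-backed query transform and the query engine.\n",
    "CANNED_ANSWERS = {\n",
    "    'Who is the author that founded Viaweb?': 'Paul Graham',\n",
    "    'In which city did Paul Graham found Viaweb?': 'Cambridge',\n",
    "}\n",
    "\n",
    "\n",
    "def decompose(query, reasoning):\n",
    "    # Pick the next sub-question based on the answers gathered so far.\n",
    "    if 'Paul Graham' not in reasoning:\n",
    "        return 'Who is the author that founded Viaweb?'\n",
    "    if 'Cambridge' not in reasoning:\n",
    "        return 'In which city did Paul Graham found Viaweb?'\n",
    "    return 'None'  # nothing left to ask\n",
    "\n",
    "\n",
    "def answer(sub_query):\n",
    "    return CANNED_ANSWERS[sub_query]\n",
    "\n",
    "\n",
    "def run_multi_step(query, num_steps=3):\n",
    "    reasoning = []  # answers gathered so far\n",
    "    sub_qa = []  # (sub-question, answer) pairs\n",
    "    for _ in range(num_steps):\n",
    "        sub_query = decompose(query, reasoning)\n",
    "        if 'none' in sub_query.lower():\n",
    "            break  # the decomposer signalled that no further step is needed\n",
    "        response = answer(sub_query)\n",
    "        sub_qa.append((sub_query, response))\n",
    "        reasoning.append(response)\n",
    "    return sub_qa\n",
    "```\n",
    "\n",
    "Each sub-question is generated only after the previous answer is known, and the loop ends either when the step budget is used up or when the decomposer returns `None`."
   ]
  },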
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -U llama-index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ[\"OPENAI_API_KEY\"] = \"sk-...\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since workflows are async first, this all runs fine in a notebook. If you were running in your own code, you would want to use `asyncio.run()` to start an async event loop if one isn't already running.\n",
    "\n",
    "```python\n",
    "async def main():\n",
    "    <async code>\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    import asyncio\n",
    "    asyncio.run(main())\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The Workflow\n",
    "\n",
    "## Designing the Workflow\n",
    "\n",
    "MultiStepQueryEngine consists of some clearly defined steps\n",
    "1. Indexing data, creating an index.\n",
    "2. Create multiple sub-queries to answer the query.\n",
    "3. Synthesize the final response\n",
    "\n",
    "With this in mind, we can create events and workflow steps to follow this process!\n",
    "\n",
    "### The Workflow Event\n",
    "\n",
    "To handle these steps, we need to define `QueryMultiStepEvent`\n",
    "\n",
    "The other steps will use the built-in `StartEvent` and `StopEvent` events."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define Event"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[nltk_data] Downloading package punkt_tab to\n",
      "[nltk_data]     /Users/ravithejad/Desktop/llamaindex/lib/python3.9/sit\n",
      "[nltk_data]     e-packages/llama_index/core/_static/nltk_cache...\n",
      "[nltk_data]   Package punkt_tab is already up-to-date!\n"
     ]
    }
   ],
   "source": [
    "from llama_index.core.workflow import Event\n",
    "from typing import Dict, List, Any\n",
    "from llama_index.core.schema import NodeWithScore\n",
    "\n",
    "\n",
    "class QueryMultiStepEvent(Event):\n",
    "    \"\"\"\n",
    "    Event containing results of a multi-step query process.\n",
    "\n",
    "    Attributes:\n",
    "        nodes (List[NodeWithScore]): List of nodes with their associated scores.\n",
    "        source_nodes (List[NodeWithScore]): List of source nodes with their scores.\n",
    "        final_response_metadata (Dict[str, Any]): Metadata associated with the final response.\n",
    "    \"\"\"\n",
    "\n",
    "    nodes: List[NodeWithScore]\n",
    "    source_nodes: List[NodeWithScore]\n",
    "    final_response_metadata: Dict[str, Any]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define Workflow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.indices.query.query_transform.base import (\n",
    "    StepDecomposeQueryTransform,\n",
    ")\n",
    "from llama_index.core.response_synthesizers import (\n",
    "    get_response_synthesizer,\n",
    ")\n",
    "\n",
    "from llama_index.core.schema import QueryBundle, TextNode\n",
    "\n",
    "from llama_index.core.workflow import (\n",
    "    Context,\n",
    "    Workflow,\n",
    "    StartEvent,\n",
    "    StopEvent,\n",
    "    step,\n",
    ")\n",
    "\n",
    "from llama_index.core import Settings\n",
    "from llama_index.core.llms import LLM\n",
    "\n",
    "from typing import cast\n",
    "from IPython.display import Markdown, display\n",
    "\n",
    "\n",
    "class MultiStepQueryEngineWorkflow(Workflow):\n",
    "    def combine_queries(\n",
    "        self,\n",
    "        query_bundle: QueryBundle,\n",
    "        prev_reasoning: str,\n",
    "        index_summary: str,\n",
    "        llm: LLM,\n",
    "    ) -> QueryBundle:\n",
    "        \"\"\"Combine queries using StepDecomposeQueryTransform.\"\"\"\n",
    "        transform_metadata = {\n",
    "            \"prev_reasoning\": prev_reasoning,\n",
    "            \"index_summary\": index_summary,\n",
    "        }\n",
    "        return StepDecomposeQueryTransform(llm=llm)(\n",
    "            query_bundle, metadata=transform_metadata\n",
    "        )\n",
    "\n",
    "    def default_stop_fn(self, stop_dict: Dict) -> bool:\n",
    "        \"\"\"Stop function for multi-step query combiner.\"\"\"\n",
    "        query_bundle = cast(QueryBundle, stop_dict.get(\"query_bundle\"))\n",
    "        if query_bundle is None:\n",
    "            raise ValueError(\"Response must be provided to stop function.\")\n",
    "\n",
    "        return \"none\" in query_bundle.query_str.lower()\n",
    "\n",
    "    @step\n",
    "    async def query_multistep(\n",
    "        self, ctx: Context, ev: StartEvent\n",
    "    ) -> QueryMultiStepEvent:\n",
    "        \"\"\"Execute multi-step query process.\"\"\"\n",
    "        prev_reasoning = \"\"\n",
    "        cur_response = None\n",
    "        should_stop = False\n",
    "        cur_steps = 0\n",
    "\n",
    "        # use response\n",
    "        final_response_metadata: Dict[str, Any] = {\"sub_qa\": []}\n",
    "\n",
    "        text_chunks = []\n",
    "        source_nodes = []\n",
    "\n",
    "        query = ev.get(\"query\")\n",
    "        await ctx.store.set(\"query\", ev.get(\"query\"))\n",
    "\n",
    "        llm = Settings.llm\n",
    "        stop_fn = self.default_stop_fn\n",
    "\n",
    "        num_steps = ev.get(\"num_steps\")\n",
    "        query_engine = ev.get(\"query_engine\")\n",
    "        index_summary = ev.get(\"index_summary\")\n",
    "\n",
    "        while not should_stop:\n",
    "            if num_steps is not None and cur_steps >= num_steps:\n",
    "                should_stop = True\n",
    "                break\n",
    "            elif should_stop:\n",
    "                break\n",
    "\n",
    "            updated_query_bundle = self.combine_queries(\n",
    "                QueryBundle(query_str=query),\n",
    "                prev_reasoning,\n",
    "                index_summary,\n",
    "                llm,\n",
    "            )\n",
    "\n",
    "            print(\n",
    "                f\"Created query for the step - {cur_steps} is: {updated_query_bundle}\"\n",
    "            )\n",
    "\n",
    "            stop_dict = {\"query_bundle\": updated_query_bundle}\n",
    "            if stop_fn(stop_dict):\n",
    "                should_stop = True\n",
    "                break\n",
    "\n",
    "            cur_response = query_engine.query(updated_query_bundle)\n",
    "\n",
    "            # append to response builder\n",
    "            cur_qa_text = (\n",
    "                f\"\\nQuestion: {updated_query_bundle.query_str}\\n\"\n",
    "                f\"Answer: {cur_response!s}\"\n",
    "            )\n",
    "            text_chunks.append(cur_qa_text)\n",
    "            for source_node in cur_response.source_nodes:\n",
    "                source_nodes.append(source_node)\n",
    "            # update metadata\n",
    "            final_response_metadata[\"sub_qa\"].append(\n",
    "                (updated_query_bundle.query_str, cur_response)\n",
    "            )\n",
    "\n",
    "            prev_reasoning += (\n",
    "                f\"- {updated_query_bundle.query_str}\\n\" f\"- {cur_response!s}\\n\"\n",
    "            )\n",
    "            cur_steps += 1\n",
    "\n",
    "        nodes = [\n",
    "            NodeWithScore(node=TextNode(text=text_chunk))\n",
    "            for text_chunk in text_chunks\n",
    "        ]\n",
    "        return QueryMultiStepEvent(\n",
    "            nodes=nodes,\n",
    "            source_nodes=source_nodes,\n",
    "            final_response_metadata=final_response_metadata,\n",
    "        )\n",
    "\n",
    "    @step\n",
    "    async def synthesize(\n",
    "        self, ctx: Context, ev: QueryMultiStepEvent\n",
    "    ) -> StopEvent:\n",
    "        \"\"\"Synthesize the response.\"\"\"\n",
    "        response_synthesizer = get_response_synthesizer()\n",
    "        query = await ctx.store.get(\"query\", default=None)\n",
    "        final_response = await response_synthesizer.asynthesize(\n",
    "            query=query,\n",
    "            nodes=ev.nodes,\n",
    "            additional_source_nodes=ev.source_nodes,\n",
    "        )\n",
    "        final_response.metadata = ev.final_response_metadata\n",
    "\n",
    "        return StopEvent(result=final_response)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2024-08-26 14:16:04--  https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/paul_graham/paul_graham_essay.txt\n",
      "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 2606:50c0:8000::154, 2606:50c0:8002::154, 2606:50c0:8001::154, ...\n",
      "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|2606:50c0:8000::154|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 75042 (73K) [text/plain]\n",
      "Saving to: ‘data/paul_graham/paul_graham_essay.txt’\n",
      "\n",
      "data/paul_graham/pa 100%[===================>]  73.28K  --.-KB/s    in 0.01s   \n",
      "\n",
      "2024-08-26 14:16:04 (6.91 MB/s) - ‘data/paul_graham/paul_graham_essay.txt’ saved [75042/75042]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!mkdir -p 'data/paul_graham/'\n",
    "!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/paul_graham/paul_graham_essay.txt' -O 'data/paul_graham/paul_graham_essay.txt'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core import SimpleDirectoryReader\n",
    "\n",
    "documents = SimpleDirectoryReader(\"data/paul_graham\").load_data()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup LLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.llms.openai import OpenAI\n",
    "\n",
    "llm = OpenAI(model=\"gpt-4\")\n",
    "\n",
    "Settings.llm = llm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create Index and QueryEngine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core import VectorStoreIndex\n",
    "\n",
    "index = VectorStoreIndex.from_documents(\n",
    "    documents=documents,\n",
    ")\n",
    "\n",
    "query_engine = index.as_query_engine()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the Workflow!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "w = MultiStepQueryEngineWorkflow(timeout=200)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set the parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sets maximum number of steps taken to answer the query.\n",
    "num_steps = 3\n",
    "\n",
    "# Set summary of the index, useful to create modified query at each step.\n",
    "index_summary = \"Used to answer questions about the author\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test with a query"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "query = \"In which city did the author found his first company, Viaweb?\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Created query for the step - 0 is: Who is the author who founded Viaweb?\n",
      "Created query for the step - 1 is: In which city did Paul Graham found his first company, Viaweb?\n",
      "Created query for the step - 2 is: None\n"
     ]
    },
    {
     "data": {
      "text/markdown": [
       "> Question: In which city did the author found his first company, Viaweb?"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/markdown": [
       "Answer: The author founded his first company, Viaweb, in Cambridge."
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "result = await w.run(\n",
    "    query=query,\n",
    "    query_engine=query_engine,\n",
    "    index_summary=index_summary,\n",
    "    num_steps=num_steps,\n",
    ")\n",
    "\n",
    "# If created query in a step is None, the process will be stopped.\n",
    "\n",
    "display(\n",
    "    Markdown(\"> Question: {}\".format(query)),\n",
    "    Markdown(\"Answer: {}\".format(result)),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Display step-queries created"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "[('Who is the author who founded Viaweb?', 'The author who founded Viaweb is Paul Graham.'), ('In which city did Paul Graham found his first company, Viaweb?', 'Paul Graham founded his first company, Viaweb, in Cambridge.')]"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "sub_qa = result.metadata[\"sub_qa\"]\n",
    "tuples = [(t[0], t[1].response) for t in sub_qa]\n",
    "display(Markdown(f\"{tuples}\"))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama-index-caVs7DDe-py3.10",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
